from callbacks import GiveModelToEnvCallback, FixPolicyActionsCallback, CustomCheckpointCallback, CustomEvalCallback
from stable_baselines3.common.callbacks import CallbackList
from wandb.integration.sb3 import WandbCallback
import wandb

from env_setups import get_standard_matrix_env, get_simple_Bayesian_env, get_mspm_env, get_matrix_design_env
from rl_trainer_setup import get_custom_training_algorithm

import os, copy


def get_latest_checkpoint(folder_path):
    """Return the path (without extension) of the newest checkpoint in a folder.

    A directory entry counts as a checkpoint when the text between its first
    and second ``.`` is ``zip`` and the third ``_``-separated token of its
    name is an integer step count (e.g. ``rl_model_1000_steps.zip``).
    Entries that do not match this pattern -- sub-directories, log files,
    unrelated archives -- are ignored rather than raising.

    Args:
        folder_path: Directory to scan.

    Returns:
        ``folder_path + "/" + <file name up to its first dot>`` for the
        checkpoint with the highest step count, or ``None`` when the folder
        holds no matching checkpoint.

    Raises:
        OSError: Propagated from ``os.listdir`` if the folder cannot be read.
    """
    max_steps = -1
    latest_checkpoint = None
    for file in os.listdir(folder_path):
        name_parts = file.split('.')
        # Entries without an extension (e.g. sub-directories) have no
        # second part; they cannot be checkpoints.
        if len(name_parts) < 2 or name_parts[1] != 'zip':
            continue
        try:
            steps = int(file.split('_')[2])
        except (IndexError, ValueError):
            # A .zip whose name does not carry a step count in the expected
            # position; skip it instead of aborting the scan.
            continue
        if steps > max_steps:
            max_steps = steps
            latest_checkpoint = folder_path + "/" + name_parts[0]
    return latest_checkpoint


def train_run(config_dict):
    """Run one training experiment described by ``config_dict``.

    Steps, in order:
      1. Build the experiment name from a fixed list of config fields and
         derive the log folder ``<this file's dir>/logs/<exp name>``.
      2. Patch wandb's TensorBoard integration for that folder, start a
         wandb run in project ``StackMDP`` and register the config.
      3. Store the log folder under ``config_dict["log_folder"]``
         (the caller's dict is modified in place).
      4. Create a training env and a separate evaluation env for the
         configured ``experiment_type``.
      5. Create the callbacks and the training algorithm, train for
         ``config_dict['max_steps']`` timesteps, then finish the wandb run.

    Args:
        config_dict: Experiment configuration. Must contain every key read
            below; a missing key raises ``KeyError``.
    """
    reward_steps = config_dict['tot_num_reward_steps']
    eq_steps = config_dict['tot_num_eq_steps']
    excluded_frac = config_dict['frac_excluded_eq_steps']

    # Fields that make up the experiment name, in order, joined by dots.
    name_fields = (
        'experiment_type',
        'matrix_game_name',
        'max_steps',
        'algorithm',
        'seed',
        'tot_num_reward_steps',
        'tot_num_eq_steps',
        'frac_excluded_eq_steps',
        'critic_obs',
        'fix_episode_actions',
        "randomized",
        "randomization_type",
        "followers_algorithm",
        "num_followers_messages",
    )
    exp_name = "exp." + ".".join(str(config_dict[key]) for key in name_fields)

    module_dir = os.path.abspath(os.path.dirname(__file__))
    log_folder = os.path.join(module_dir + "/logs/", exp_name)

    wandb.tensorboard.patch(root_logdir=log_folder, pytorch=True)
    wandb.init(project="StackMDP")
    wandb.config.setdefaults(config_dict)

    # Added after the wandb config registration above, so it is only
    # visible to the code that receives config_dict from here on.
    config_dict["log_folder"] = log_folder

    # Pick the env factory once; any unrecognised experiment type falls
    # back to the standard matrix env.
    experiment_type = config_dict["experiment_type"]
    if experiment_type == "simple_Bayesian":
        make_env = get_simple_Bayesian_env
    elif experiment_type == "matrix_design":
        make_env = get_matrix_design_env
    elif experiment_type == "mspm":
        make_env = get_mspm_env
    else:
        make_env = get_standard_matrix_env
    env = make_env(config_dict)
    eval_env = make_env(config_dict)

    model_to_env_cb = GiveModelToEnvCallback()
    fix_actions_cb = FixPolicyActionsCallback()
    checkpoint_cb = CustomCheckpointCallback(
        save_freq=10000000,
        save_path=log_folder,
    )

    eval_env.is_eval = True
    eval_every = int(config_dict['max_steps'] / 1000)
    eval_cb = CustomEvalCallback(eval_env, eval_freq=eval_every)

    wandb_cb = WandbCallback()

    # n_steps = equilibrium steps that are not excluded, plus reward steps.
    excluded_eq_steps = int(excluded_frac * eq_steps)
    n_steps = eq_steps - excluded_eq_steps + reward_steps
    mod = get_custom_training_algorithm(
        config_dict['algorithm'],
        env,
        n_steps=n_steps,
        tensorboard_folder=log_folder,
    )

    callbacks = [wandb_cb, model_to_env_cb, checkpoint_cb, eval_cb]
    # The flag is compared against the string "True", not a boolean.
    if config_dict['fix_episode_actions'] == "True":
        callbacks.append(fix_actions_cb)

    mod.learn(
        total_timesteps=config_dict['max_steps'],
        callback=CallbackList(callbacks),
    )

    wandb.finish()